/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved.
 * Description: 集合通信reducescatter算子融合头文件
 * Author: ningshuqi
 * Create: 2023-5-18
 */

#ifndef UNTITLED4_HCOM_REDUCESCATTER_FUSION_H
#define UNTITLED4_HCOM_REDUCESCATTER_FUSION_H

#include <string>
#include <vector>

#include "hccl/base.h"
#include "hcom_alltoallvc_fusion.h"
#include "hcom_allgather_fusion.h"
#include "common/optimizer/graph_optimizer.h"
#include "common/optimizer/graph_optimizer_types.h"
#include "graph/compute_graph.h"
#include "op_fusion_base_pub.h"

namespace hccl {
// Records all nodes newly added to the graph by the reducescatter fusion.
// A default-constructed instance holds no nodes.
using ReduceScatterFusionNodesInfo = struct reducescatterFusionNodesInfo {
    // Presumably the SplitV nodes created by AddSendDataSplitV; starts empty.
    std::vector<ge::NodePtr> sendDataSplitVs{};
    // Presumably the fused node created by AddReduceScatterNode; null until set.
    ge::NodePtr reducescatterFusionNodePtr{nullptr};

    reducescatterFusionNodesInfo() = default;
};

/**
 * Fusion pass for HCOM reducescatter operators, built on OpFusionBase.
 *
 * NOTE(review): the per-method descriptions below are derived from the declarations
 * and the original comments only; the exact fusion criteria live in the .cpp.
 */
class HcomReduceScatterFusion : public OpFusionBase {
public:
    HcomReduceScatterFusion();
    ~HcomReduceScatterFusion() override;
    // Entry point for the pass on `graph` (overrides OpFusionBase::Run).
    HcclResult Run(ge::ComputeGraph& graph) override;

private:
    // --- Candidate collection / grouping (presumably keyed by fusion option/label — TODO confirm) ---
    HcclResult GetFusionOps(ge::ComputeGraph& graph, FusionInfos& fusionOps);
    HcclResult GetFusionOpInfo(ge::NodePtr& nodePtr, FusionInfos& fusionOps);
    HcclResult GetFusionOption(const ge::NodePtr& nodePtr, FusionOption& fusionOption);
    HcclResult GenerateFusionLabel(const FusionOption& fusionOption, std::string& fusionLabel);
    HcclResult FuseOps(ge::ComputeGraph& graph, FusionSection& fusionSection);
    HcclResult RunFusionOpsReduceScatter(ge::ComputeGraph& graph, std::vector<ge::NodePtr>& fusionOps);

    // Records each reducescatter op's input data edges, input control edges, output data edges
    // and output control edges; once they are saved, the original reducescatter nodes are removed.
    HcclResult RemoveOpsEdges(ge::ComputeGraph& graph, std::vector<ge::NodePtr>& fusionOps,
                              std::vector<CommonNodeInfo>& nodeInfos, ge::OpDescPtr& fusedOp);

    // Helpers that, judging by their names, collect the peer anchors attached to srcNodePtr's
    // edges into the given output vectors (presumably used by RemoveOpsEdges — TODO confirm).
    HcclResult GetPeerOutDataToInData(std::vector<ge::OutDataAnchorPtr>& peerOutDataAnchorVec,
                                      ge::NodePtr& srcNodePtr);
    HcclResult GetPeerOutDataToInControl(std::vector<ge::OutDataAnchorPtr>& peerOutDataToInControlVec,
                                         ge::NodePtr& srcNodePtr);
    HcclResult GetPeerOutControlToInControl(std::vector<ge::OutControlAnchorPtr>& peerOutControlToInControlVec,
                                            ge::NodePtr& srcNodePtr);
    HcclResult GetPeerAnchorFromOutData(std::vector<ge::InDataAnchorPtr>& peerInDataFromOutDataVec,
                                        std::vector<ge::InControlAnchorPtr>& peerInControlFromOutDataVec,
                                        ge::NodePtr& srcNodePtr);
    HcclResult GetPeerInDataAnchorFromOutData(std::vector<ge::InDataAnchorPtr>& peerInDataFromOutDataVec,
                                              ge::OutDataAnchorPtr outDataAnchor, ge::NodePtr& srcNodePtr);
    HcclResult GetPeerInControlAnchorFromOutData(std::vector<ge::InControlAnchorPtr>& peerInControlFromOutDataVec,
                                                 ge::OutDataAnchorPtr outDataAnchor, ge::NodePtr& srcNodePtr);
    HcclResult GetPeerInControlFromOutControl(std::vector<ge::InControlAnchorPtr>& peerInControlFromOutControlVec,
                                              ge::NodePtr& srcNodePtr);
    // Fills rankSize, group and nodeName for srcNodePtr (non-const refs; presumably outputs — TODO confirm).
    HcclResult GetReduceScatterOpInfo(s32& rankSize, std::string& group, std::string& nodeName,
                                      ge::NodePtr& srcNodePtr);

    // Creates nodes, adds them to the graph, and adds the data edges:
    // peerOutDataAnchor, peerInDataAnchor.
    HcclResult AddFusionNode(ge::ComputeGraph& graph, std::vector<CommonNodeInfo>& nodeInfos,
                             ReduceScatterFusionNodesInfo& fusionNodesInfo, ge::OpDescPtr& fusedOp);
    HcclResult AddSendDataSplitV(ge::ComputeGraph& graph, std::vector<CommonNodeInfo>& nodeInfos,
                                 ReduceScatterFusionNodesInfo& fusionNodesInfo);
    HcclResult AddReduceScatterNode(ge::ComputeGraph& graph, std::vector<CommonNodeInfo>& nodeInfos,
                                    ReduceScatterFusionNodesInfo& fusionNodesInfo, ge::OpDescPtr& fusedOp);

    // Restores the control edges: peerOutDataToInControl, peerOutControlAnchor,
    // peerInControlFromOutData, peerInControlAnchor.
    HcclResult RestoreOpsEdges(std::vector<CommonNodeInfo>& nodeInfos,
                               ReduceScatterFusionNodesInfo& fusionNodesInfo);
    HcclResult AddOpsEdge(const ge::OutDataAnchorPtr &src, const ge::InDataAnchorPtr &dst);
    // NOTE(review): nodeName/nodeValue/dim are taken by value; switching to const& would need a
    // matching change in the .cpp definition, so the signature is kept as-is here.
    HcclResult CreateConstNode(ge::NodePtr& nodePtr, std::string nodeName, std::vector<int32_t> nodeValue,
                               std::vector<int64_t> dim, ge::ComputeGraph& graph);
    HcclResult CreateSplitVNode(SplitVNodeInfo& splitvNodeInfo, ge::ComputeGraph& graph);
    HcclResult GetFusionOpsSlices(FusionInfos& fusionInfos, FusionInfos& fusionInfosTemp);
};
}
#endif
